{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from collections import defaultdict\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Skim the following section and skip ahead to the SelectFromModel() section.\n",
    "## This example will be revisited in Chapter 5."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_root = '../chapter5/datasets/nsl-kdd'\n",
    "train_file = os.path.join(dataset_root, 'KDDTrain+.txt')\n",
    "test_file = os.path.join(dataset_root, 'KDDTest+.txt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Original KDD dataset feature names obtained from \n",
    "# http://kdd.ics.uci.edu/databases/kddcup99/kddcup.names\n",
    "# http://kdd.ics.uci.edu/databases/kddcup99/kddcup99.html\n",
    "\n",
    "header_names = ['duration', 'protocol_type', 'service', 'flag', 'src_bytes', 'dst_bytes', 'land', 'wrong_fragment', 'urgent', 'hot', 'num_failed_logins', 'logged_in', 'num_compromised', 'root_shell', 'su_attempted', 'num_root', 'num_file_creations', 'num_shells', 'num_access_files', 'num_outbound_cmds', 'is_host_login', 'is_guest_login', 'count', 'srv_count', 'serror_rate', 'srv_serror_rate', 'rerror_rate', 'srv_rerror_rate', 'same_srv_rate', 'diff_srv_rate', 'srv_diff_host_rate', 'dst_host_count', 'dst_host_srv_count', 'dst_host_same_srv_rate', 'dst_host_diff_srv_rate', 'dst_host_same_src_port_rate', 'dst_host_srv_diff_host_rate', 'dst_host_serror_rate', 'dst_host_srv_serror_rate', 'dst_host_rerror_rate', 'dst_host_srv_rerror_rate', 'attack_type', 'success_pred']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Differentiating between nominal, binary, and numeric features\n",
    "col_names = np.array(header_names)\n",
    "\n",
    "nominal_idx = [1, 2, 3]\n",
    "binary_idx = [6, 11, 13, 14, 20, 21]\n",
    "# sorted() keeps the numeric column order deterministic instead of relying\n",
    "# on set iteration order, which is an implementation detail\n",
    "numeric_idx = sorted(set(range(41)).difference(nominal_idx).difference(binary_idx))\n",
    "\n",
    "nominal_cols = col_names[nominal_idx].tolist()\n",
    "binary_cols = col_names[binary_idx].tolist()\n",
    "numeric_cols = col_names[numeric_idx].tolist()\n",
    "\n",
    "# training_attack_types.txt maps each of the 22 different attacks to 1 of 4 categories\n",
    "# file obtained from http://kdd.ics.uci.edu/databases/kddcup99/training_attack_types\n",
    "\n",
    "category = defaultdict(list)\n",
    "category['benign'].append('normal')\n",
    "\n",
    "with open('../chapter5/datasets/training_attack_types.txt', 'r') as f:\n",
    "    for line in f:\n",
    "        line = line.strip()\n",
    "        if not line:\n",
    "            # tolerate blank lines (e.g. a trailing empty line in the file)\n",
    "            continue\n",
    "        attack, cat = line.split(' ')\n",
    "        category[cat].append(attack)\n",
    "\n",
    "# Invert {category: [attacks]} into {attack: category}\n",
    "attack_mapping = dict((v,k) for k in category for v in category[k])\n",
    "\n",
    "# split into train and test dataframes\n",
    "# (the lambda lookup raises KeyError on an attack name missing from the mapping)\n",
    "train_df = pd.read_csv(train_file, names=header_names)\n",
    "train_df['attack_category'] = train_df['attack_type'] \\\n",
    "                                .map(lambda x: attack_mapping[x])\n",
    "train_df = train_df.drop(['success_pred'], axis=1)\n",
    "\n",
    "test_df = pd.read_csv(test_file, names=header_names)\n",
    "test_df['attack_category'] = test_df['attack_type'] \\\n",
    "                                .map(lambda x: attack_mapping[x])\n",
    "test_df = test_df.drop(['success_pred'], axis=1)\n",
    "\n",
    "train_Y = train_df['attack_category']\n",
    "train_x_raw = train_df.drop(['attack_category','attack_type'], axis=1)\n",
    "test_Y = test_df['attack_category']\n",
    "test_x_raw = test_df.drop(['attack_category','attack_type'], axis=1)\n",
    "\n",
    "# One-hot encode on the concatenated frame so that train and test end up\n",
    "# with the same set of dummy columns\n",
    "combined_df_raw = pd.concat([train_x_raw, test_x_raw])\n",
    "combined_df = pd.get_dummies(combined_df_raw, columns=nominal_cols, drop_first=True)\n",
    "\n",
    "# .copy() makes each split an independent frame; assigning the scaled values\n",
    "# into a bare slice of combined_df triggers pandas' SettingWithCopyWarning\n",
    "train_x = combined_df[:len(train_x_raw)].copy()\n",
    "test_x = combined_df[len(train_x_raw):].copy()\n",
    "\n",
    "# Store dummy variable feature names\n",
    "dummy_variables = list(set(train_x)-set(combined_df_raw))\n",
    "\n",
    "# Apply StandardScaler standardization\n",
    "# (fitted on the training split only, then applied to both splits)\n",
    "standard_scaler = StandardScaler().fit(train_x[numeric_cols])\n",
    "\n",
    "train_x[numeric_cols] = \\\n",
    "    standard_scaler.transform(train_x[numeric_cols])\n",
    "\n",
    "test_x[numeric_cols] = \\\n",
    "    standard_scaler.transform(test_x[numeric_cols])\n",
    "\n",
    "# Binary labels: 0 for 'benign', 1 for every attack category.\n",
    "# Use == (value equality); `is` compares object identity and only matched\n",
    "# here because of string interning.\n",
    "train_Y_bin = train_Y.apply(lambda x: 0 if x == 'benign' else 1)\n",
    "test_Y_bin = test_Y.apply(lambda x: 0 if x == 'benign' else 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[9330  381]\n",
      " [4214 8619]]\n",
      "0.203823633783\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import confusion_matrix, zero_one_loss\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "# Baseline: decision tree trained on every feature (seeded for repeatability)\n",
    "clf = DecisionTreeClassifier(random_state=0).fit(train_x, train_Y_bin)\n",
    "pred_y = clf.predict(test_x)\n",
    "\n",
    "# Confusion matrix and misclassification rate on the test split\n",
    "results = confusion_matrix(test_Y_bin, pred_y)\n",
    "error = zero_one_loss(test_Y_bin, pred_y)\n",
    "\n",
    "print(results)\n",
    "print(error)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Applying SelectFromModel() to find out which features are the most important and should be kept."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.feature_selection import SelectFromModel\n",
    "\n",
    "# prefit=True: reuse the already-trained tree instead of fitting it again\n",
    "sfm = SelectFromModel(estimator=clf, prefit=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original num features: 119, selected num features: 7\n"
     ]
    }
   ],
   "source": [
    "train_x_new = sfm.transform(train_x)\n",
    "\n",
    "# Compare the feature count before and after selection\n",
    "n_original = train_x.shape[1]\n",
    "n_selected = train_x_new.shape[1]\n",
    "summary = \"Original num features: {}, selected num features: {}\"\n",
    "print(summary.format(n_original, n_selected))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column positions ordered from most to least important\n",
    "importances = clf.feature_importances_\n",
    "indices = importances.argsort()[::-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.\tsrc_bytes - 0.739353803124338\n",
      "1.\tservice_ecr_i - 0.07537443951799272\n",
      "2.\tservice_http - 0.056288533270470384\n",
      "3.\tdst_host_same_srv_rate - 0.030188078488584003\n",
      "4.\tdst_bytes - 0.02235518870418086\n",
      "5.\thot - 0.02161534757423723\n",
      "6.\tlogged_in - 0.010299399024875614\n",
      "7.\tservice_ftp_data - 0.007153156775607418\n",
      "8.\tdst_host_srv_count - 0.004581226762306193\n",
      "9.\tprotocol_type_tcp - 0.004392896490987884\n",
      "10.\tduration - 0.0036640994569887724\n",
      "11.\tdst_host_srv_diff_host_rate - 0.0033976880323085775\n",
      "12.\tdst_host_rerror_rate - 0.0033742156274434855\n",
      "13.\tcount - 0.002156246362157632\n",
      "14.\tdst_host_diff_srv_rate - 0.0018632747975075037\n",
      "15.\tservice_private - 0.0015150555545468847\n",
      "16.\tdst_host_srv_serror_rate - 0.0012895009893495672\n",
      "17.\tflag_RSTO - 0.0012092707068345266\n",
      "18.\tdst_host_count - 0.0012087755633389148\n",
      "19.\tservice_smtp - 0.0010333469492228471\n",
      "20.\tflag_S1 - 0.0010025689546886581\n",
      "21.\tflag_REJ - 0.0009382141529412343\n",
      "22.\tservice_finger - 0.000782544179058021\n",
      "23.\tservice_other - 0.0007470303397855699\n",
      "24.\tserror_rate - 0.0006961169739255497\n",
      "25.\tservice_auth - 0.00045270034697481043\n",
      "26.\tdst_host_same_src_port_rate - 0.0003044547983537414\n",
      "27.\tservice_X11 - 0.0002841496445024789\n",
      "28.\tservice_time - 0.0002759699994097465\n",
      "29.\tdiff_srv_rate - 0.00018830601511933\n",
      "30.\tservice_pm_dump - 0.00015902646160888522\n",
      "31.\tservice_telnet - 0.00013317892430146747\n",
      "32.\tdst_host_serror_rate - 0.00013264198009174918\n",
      "33.\tflag_RSTOS0 - 0.00012724109394735437\n",
      "34.\tnum_shells - 0.00012106816370982825\n",
      "35.\trerror_rate - 0.0001175974659305329\n",
      "36.\tsrv_count - 9.567562736593045e-05\n",
      "37.\tservice_tftp_u - 9.56396789549338e-05\n",
      "38.\tservice_urp_i - 9.481733643984882e-05\n",
      "39.\tnum_access_files - 9.469776995874786e-05\n",
      "40.\tservice_tim_i - 7.087262857010326e-05\n",
      "41.\tsrv_rerror_rate - 7.07316857862168e-05\n",
      "42.\tservice_login - 6.363383051761173e-05\n",
      "43.\tflag_S2 - 6.352902301409212e-05\n",
      "44.\tservice_domain_u - 6.143436563685351e-05\n",
      "45.\tflag_SF - 5.461353413799855e-05\n",
      "46.\tservice_imap4 - 4.759568625830915e-05\n",
      "47.\tnum_file_creations - 4.620083964430473e-05\n",
      "48.\tnum_compromised - 4.3431444431814e-05\n",
      "49.\tsrv_diff_host_rate - 4.065901819717204e-05\n",
      "50.\tnum_failed_logins - 3.883335265789425e-05\n",
      "51.\tflag_RSTR - 3.18983663986038e-05\n",
      "52.\tnum_root - 3.185611404878063e-05\n",
      "53.\tservice_gopher - 3.115617125285532e-05\n",
      "54.\tservice_sunrpc - 3.114929234409117e-05\n",
      "55.\tservice_pop_2 - 3.107847506519203e-05\n",
      "56.\tdst_host_srv_rerror_rate - 2.9781762436301826e-05\n",
      "57.\tservice_ftp - 2.945449032161717e-05\n",
      "58.\tsame_srv_rate - 1.8961603652322847e-05\n",
      "59.\tis_guest_login - 3.980676440365067e-06\n",
      "60.\tservice_domain - 1.9639588398599587e-06\n",
      "61.\tprotocol_type_udp - 0.0\n",
      "62.\tservice_Z39_50 - 0.0\n",
      "63.\tservice_vmnet - 0.0\n",
      "64.\tsrv_serror_rate - 0.0\n",
      "65.\tis_host_login - 0.0\n",
      "66.\tnum_outbound_cmds - 0.0\n",
      "67.\tsu_attempted - 0.0\n",
      "68.\troot_shell - 0.0\n",
      "69.\tflag_S0 - 0.0\n",
      "70.\turgent - 0.0\n",
      "71.\twrong_fragment - 0.0\n",
      "72.\tland - 0.0\n",
      "73.\tflag_S3 - 0.0\n",
      "74.\tservice_aol - 0.0\n",
      "75.\tservice_eco_i - 0.0\n",
      "76.\tservice_bgp - 0.0\n",
      "77.\tservice_red_i - 0.0\n",
      "78.\tservice_netbios_ns - 0.0\n",
      "79.\tservice_netbios_ssn - 0.0\n",
      "80.\tservice_netstat - 0.0\n",
      "81.\tservice_nnsp - 0.0\n",
      "82.\tservice_nntp - 0.0\n",
      "83.\tservice_ntp_u - 0.0\n",
      "84.\tservice_pop_3 - 0.0\n",
      "85.\tservice_printer - 0.0\n",
      "86.\tservice_remote_job - 0.0\n",
      "87.\tservice_name - 0.0\n",
      "88.\tservice_rje - 0.0\n",
      "89.\tservice_shell - 0.0\n",
      "90.\tservice_sql_net - 0.0\n",
      "91.\tservice_ssh - 0.0\n",
      "92.\tservice_supdup - 0.0\n",
      "93.\tservice_systat - 0.0\n",
      "94.\tservice_urh_i - 0.0\n",
      "95.\tservice_uucp - 0.0\n",
      "96.\tservice_netbios_dgm - 0.0\n",
      "97.\tservice_mtp - 0.0\n",
      "98.\tservice_courier - 0.0\n",
      "99.\tservice_harvest - 0.0\n",
      "100.\tservice_csnet_ns - 0.0\n",
      "101.\tservice_ctf - 0.0\n",
      "102.\tservice_daytime - 0.0\n",
      "103.\tservice_discard - 0.0\n",
      "104.\tservice_echo - 0.0\n",
      "105.\tservice_uucp_path - 0.0\n",
      "106.\tservice_efs - 0.0\n",
      "107.\tservice_exec - 0.0\n",
      "108.\tservice_hostnames - 0.0\n",
      "109.\tservice_link - 0.0\n",
      "110.\tservice_whois - 0.0\n",
      "111.\tservice_http_2784 - 0.0\n",
      "112.\tservice_http_443 - 0.0\n",
      "113.\tservice_http_8001 - 0.0\n",
      "114.\tservice_iso_tsap - 0.0\n",
      "115.\tservice_klogin - 0.0\n",
      "116.\tservice_kshell - 0.0\n",
      "117.\tservice_ldap - 0.0\n",
      "118.\tflag_SH - 0.0\n"
     ]
    }
   ],
   "source": [
    "# Print every feature with its importance, following the order in `indices`\n",
    "feature_names = train_x.columns\n",
    "importances = clf.feature_importances_\n",
    "for rank, feature_idx in enumerate(indices):\n",
    "    row = \"{}.\\t{} - {}\".format(rank, feature_names[feature_idx], importances[feature_idx])\n",
    "    print(row)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
